import torch
import torch.nn as nn
from warnings import warn
from torch import cos, sin, sign, norm
from .template import ControlledSystemTemplate


class CartPole(ControlledSystemTemplate):
    """Continuous-time cart-pole system, modelled on the OpenAI Gym cartpole.

    Inspired by: https://gist.github.com/iandanforth/e3ffb67cf3623153e968f2afdfb01dc8

    The state is read from the last axis of the input tensor as
    ``[x, dx, theta, dtheta]``: cart position, cart velocity, pole angle and
    pole angular velocity. The control input returned by the controller is
    added to the force term of the cart.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Physical constants. They are set after the parent constructor, so
        # they override anything the parent may have assigned.
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        # NOTE(review): presumably half the pole length, as in Gym (the 4/3
        # factor in the dynamics matches a uniform rod about its end) — confirm.
        self.length = 0.5
        # Derived quantities. They are computed once here and are NOT updated
        # if masscart / masspole / length are modified afterwards.
        self.total_mass = self.masspole + self.masscart
        self.polemass_length = self.masspole * self.length

    def _dynamics(self, t, x_):
        """Evaluate the vector field of the cart-pole at time ``t``.

        Increments ``self.nfe`` and queries ``self._evaluate_controller(t, x_)``
        for the control input. It then computes the cart and pole
        accelerations with the Gym CartPole equations of motion.

        Returns the concatenation ``[dx, ddx, dtheta, ddtheta]`` along the last
        axis. The result is also stored on ``self.cur_f``.
        """
        self.nfe += 1  # bookkeeping: one more vector-field evaluation
        force = self._evaluate_controller(t, x_)
        # NOTE(review): assumes `force` broadcasts against the (..., 1) state
        # slices below — confirm against _evaluate_controller.

        # Column 0 (cart position) does not appear in the equations of motion.
        vel = x_[..., 1:2]
        theta = x_[..., 2:3]
        omega = x_[..., 3:4]

        cos_theta = cos(theta)
        sin_theta = sin(theta)

        # Shared term appearing in both accelerations.
        common = (force + self.polemass_length * omega**2 * sin_theta) / self.total_mass

        # Angular acceleration of the pole.
        numerator = self.gravity * sin_theta - cos_theta * common
        denominator = self.length * (
            4.0 / 3.0 - self.masspole * cos_theta**2 / self.total_mass
        )
        theta_acc = numerator / denominator

        # Linear acceleration of the cart.
        cart_acc = common - self.polemass_length * theta_acc * cos_theta / self.total_mass

        derivative = torch.cat([vel, cart_acc, omega, theta_acc], -1)
        self.cur_f = derivative
        return derivative

    def render(self):
        """Draw the system; not available yet and always raises."""
        raise NotImplementedError("TODO: add the rendering from OpenAI Gym")
